{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bIzs9AIWgobW"
      },
      "source": [
        "# Hyperparameter Ensembles for Robustness and Uncertainty Quantification\n",
        "\n",
        "*Florian Wenzel, April 8th 2021. Licensed under the Apache License, Version 2.0.*\n",
        "\n",
        "Recently, we proposed **Hyper-deep Ensembles** ([Wenzel et al., NeurIPS 2020](https://arxiv.org/abs/2006.13570)) a simple, yet powerful, extension of [deep ensembles](https://arxiv.org/abs/1612.01474). The approach works with any given deep network architecture and, therefore, can be easily integrated (and improve) a machine learning system that is already used in production.\n",
        "\n",
        "Hyper-deep ensembles improve the performance of a given deep network by forming an ensemble over multiple variants of that architecture where each member uses different hyperparameters. In this notebook we consider a ResNet-20 architecture with block-wise $\\ell_2$-regularization parameters and a label smoothing parameter. We construct an ensemble of 4 members where each member uses a  different set of hyperparameters. This leads to an ensemble of **diverse members**, i.e., members that are complementary in their predictions. The final ensemble greatly improves the prediction performance and the robustness of the model, e.g., in out-of-distribution settings.\n",
        "\n",
        "Let's start with some boilerplate code for data loading and the model definition.\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "01aweUS5Ny9b"
      },
      "source": [
        "Requirements:\n",
        "```bash\n",
        "!pip install \"git+https://github.com/google/uncertainty-baselines.git#egg=uncertainty_baselines\"\n",
        "```"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "v3eFRRs-Qnxp"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "import tensorflow_datasets as tfds\n",
        "import numpy as np\n",
        "\n",
        "import uncertainty_baselines as ub"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mv2-1HmBJ4QK"
      },
      "outputs": [],
      "source": [
        "def _ensemble_accuracy(labels, logits_list):\n",
        "  \"\"\"Compute the accuracy resulting from the ensemble prediction.\"\"\"\n",
        "  per_probs = tf.nn.softmax(logits_list)\n",
        "  probs = tf.reduce_mean(per_probs, axis=0)\n",
        "  acc = tf.keras.metrics.SparseCategoricalAccuracy()\n",
        "  acc.update_state(labels, probs)\n",
        "  return acc.result()\n",
        "\n",
        "def _ensemble_cross_entropy(labels, logits):\n",
        "  logits = tf.convert_to_tensor(logits)\n",
        "  ensemble_size = float(logits.shape[0])\n",
        "  labels = tf.cast(labels, tf.int32)\n",
        "  ce = tf.nn.sparse_softmax_cross_entropy_with_logits(\n",
        "      labels=tf.broadcast_to(labels[tf.newaxis, ...], tf.shape(logits)[:-1]),\n",
        "      logits=logits)\n",
        "  nll = -tf.reduce_logsumexp(-ce, axis=0) + tf.math.log(ensemble_size)\n",
        "  return tf.reduce_mean(nll)\n",
        "\n",
        "\n",
        "def greedy_selection(val_logits, val_labels, max_ens_size, objective='nll'):\n",
        "  \"\"\"Greedy procedure from Caruana et al. 2004, with replacement.\"\"\"\n",
        "\n",
        "  assert_msg = 'Unknown objective type (received {}).'.format(objective)\n",
        "  assert objective in ('nll', 'acc', 'nll-acc'), assert_msg\n",
        "\n",
        "  # Objective that should be optimized by the ensemble. Arbitrary objectives,\n",
        "  # e.g., based on nll, acc or calibration error (or combinations of those) can\n",
        "  # be used.\n",
        "  if objective == 'nll':\n",
        "    get_objective = lambda acc, nll: nll\n",
        "  elif objective == 'acc':\n",
        "    get_objective = lambda acc, nll: acc\n",
        "  else:\n",
        "    get_objective = lambda acc, nll: nll-acc\n",
        "\n",
        "  best_acc = 0.\n",
        "  best_nll = np.inf\n",
        "  best_objective = np.inf\n",
        "  ens = []\n",
        "\n",
        "  def get_ens_size():\n",
        "    return len(set(ens))\n",
        "\n",
        "  while get_ens_size() \u003c max_ens_size:\n",
        "    current_val_logits = [val_logits[model_id] for model_id in ens]\n",
        "    best_model_id = None\n",
        "    for model_id, logits in enumerate(val_logits):\n",
        "      acc = _ensemble_accuracy(val_labels, current_val_logits + [logits])\n",
        "      nll = _ensemble_cross_entropy(val_labels, current_val_logits + [logits])\n",
        "      obj = get_objective(acc, nll)\n",
        "      if obj \u003c best_objective:\n",
        "        best_acc = acc\n",
        "        best_nll = nll\n",
        "        best_objective = obj\n",
        "        best_model_id = model_id\n",
        "    if best_model_id is None:\n",
        "      print('Ensemble could not be improved: Greedy selection stops.')\n",
        "      break\n",
        "    ens.append(best_model_id)\n",
        "  return ens, best_acc, best_nll\n",
        "\n",
        "\n",
        "def parse_checkpoint_dir(checkpoint_dir):\n",
        "  \"\"\"Parse directory of checkpoints.\"\"\"\n",
        "  paths = []\n",
        "  subdirectories = tf.io.gfile.glob(os.path.join(checkpoint_dir, '*'))\n",
        "  is_checkpoint = lambda f: ('checkpoint' in f and '.index' in f)\n",
        "  print('Load checkpoints')\n",
        "  for subdir in subdirectories:\n",
        "    for path, _, files in tf.io.gfile.walk(subdir):\n",
        "      if any(f for f in files if is_checkpoint(f)):\n",
        "        latest_checkpoint = tf.train.latest_checkpoint(path)\n",
        "        paths.append(latest_checkpoint)\n",
        "        print('.', end='')\n",
        "        break\n",
        "  print('')\n",
        "  return paths"
      ]
    },
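    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The ensemble negative log-likelihood computed by `_ensemble_cross_entropy` is the NLL of the averaged member probabilities, evaluated in log space for numerical stability: with per-member cross-entropies $\\mathrm{ce}_m = -\\log p_m(y \\mid x)$ and $M$ members, $-\\log \\frac{1}{M} \\sum_m p_m(y \\mid x) = -\\mathrm{logsumexp}_m(-\\mathrm{ce}_m) + \\log M$. The NumPy-only sketch below checks this identity on random toy logits. The helper names and toy shapes are illustrative assumptions and are not part of the notebook's pipeline.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def log_softmax(logits):\n",
        "  # Numerically stable log-softmax over the last (class) axis.\n",
        "  shifted = logits - logits.max(axis=-1, keepdims=True)\n",
        "  return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))\n",
        "\n",
        "def ensemble_nll_direct(labels, logits):\n",
        "  # logits: [members, examples, classes]. Average the member probabilities,\n",
        "  # then take the negative log-probability of the true class.\n",
        "  probs = np.exp(log_softmax(logits)).mean(axis=0)\n",
        "  return -np.log(probs[np.arange(len(labels)), labels]).mean()\n",
        "\n",
        "def ensemble_nll_logsumexp(labels, logits):\n",
        "  # Same quantity via per-member log-probabilities (the negated cross-entropies).\n",
        "  num_members = logits.shape[0]\n",
        "  log_p = log_softmax(logits)[:, np.arange(len(labels)), labels]\n",
        "  max_log_p = log_p.max(axis=0)\n",
        "  lse = max_log_p + np.log(np.exp(log_p - max_log_p).sum(axis=0))\n",
        "  return (-lse + np.log(num_members)).mean()\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "toy_logits = rng.normal(size=(4, 8, 10))  # 4 members, 8 examples, 10 classes.\n",
        "toy_labels = rng.integers(0, 10, size=8)\n",
        "nll_direct = ensemble_nll_direct(toy_labels, toy_logits)\n",
        "nll_lse = ensemble_nll_logsumexp(toy_labels, toy_logits)\n",
        "print(nll_direct, nll_lse)\n",
        "```\n",
        "\n",
        "Both values should agree up to floating-point error, which is why the helper can work directly on logits without ever forming the averaged probabilities."
      ]
    },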
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hcEdoRDoCwvZ"
      },
      "outputs": [],
      "source": [
        "DATASET = 'cifar10'\n",
        "TRAIN_PROPORTION = 0.95\n",
        "BATCH_SIZE = 64\n",
        "ENSEMBLE_SIZE = 4\n",
        "CHECKPOINT_DIR = 'gs://gresearch/reliable-deep-learning/checkpoints/baselines/cifar/hyper_ensemble/'"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QEuYsfW0_A6V"
      },
      "outputs": [],
      "source": [
        "# Load data.\n",
        "ds_info = tfds.builder(DATASET).info\n",
        "num_classes = ds_info.features['label'].num_classes\n",
        "# Test set.\n",
        "steps_per_eval = ds_info.splits['test'].num_examples // BATCH_SIZE \n",
        "test_dataset = ub.datasets.get(\n",
        "      DATASET,\n",
        "      split=tfds.Split.TEST).load(batch_size=BATCH_SIZE)\n",
        "# Validation set.\n",
        "validation_percent = 1 - TRAIN_PROPORTION\n",
        "val_dataset = ub.datasets.get(\n",
        "    dataset_name=DATASET,\n",
        "    split=tfds.Split.VALIDATION,\n",
        "    validation_percent=validation_percent,\n",
        "    drop_remainder=False).load(batch_size=BATCH_SIZE)\n",
        "steps_per_val_eval = int(ds_info.splits['train'].num_examples *\n",
        "                          validation_percent) // BATCH_SIZE"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gZZO_cnVp5qC"
      },
      "source": [
        "# Let's construct the hyper-deep ensemble over a ResNet-20 architecture\n",
        "\n",
        "\n",
        "**This is the (simplified) hyper-deep ensembles construction pipeline**\n",
        "\u003e **1. Random search:** train several models on the train set using different (random) hyperparameters.\n",
        "\u003e \n",
        "\u003e **2. Ensemble construction:** on a validation set using a greedy selection method.\n",
        "\n",
        "Remark:\n",
        "*In this notebook we use a slightly simplified version of the pipeline compared to the approach of the original paper (where an additional stratification step is used). Additionally, after selecting the optimal hyperparameters the ensemble performance can be improved even more by retraining the selected models on the full train set (i.e., this time not reserving a portion for validation). The simplified pipeline in this notebook is slightly less performant but easier to implement. The simplified pipeline is similar to the ones used by Caranua et al., 2004 and Zaidi et al., 2020 in the context of neural architecture search.*\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9IXvemGwBKk-"
      },
      "source": [
        "## Step 1: Random Hyperparameter Search\n",
        "\n",
        "We start by training 100 different versions of the ResNet-20 using different $\\ell_2$-regularization parameters and label smoothing parameters. Since this would take some time we have already trained the models using a standard training script (which can be found [here](https://github.com/google/uncertainty-baselines/blob/main/baselines/cifar/deterministic.py)) and directly load the checkpoints (which can be browsed [here](https://console.cloud.google.com/storage/browser/gresearch/reliable-deep-learning/checkpoints/baselines/cifar/hyper_ensemble/)).\n",
        "\n"
      ]
    },
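    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough sketch of what the random search described above samples, the snippet below draws 100 hyperparameter configurations with NumPy. The search ranges, the number of blocks, and the dictionary keys are assumptions made for illustration; they are not taken from the training script or the checkpoints.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "NUM_TRIALS = 100  # Size of the model pool used in this notebook.\n",
        "NUM_BLOCKS = 3    # Hypothetical number of blocks with their own l2 parameter.\n",
        "\n",
        "def sample_hyperparameters(rng):\n",
        "  # Log-uniform l2 coefficients (assumed range [1e-5, 1e-2]), one per block.\n",
        "  l2 = 10.0 ** rng.uniform(-5.0, -2.0, size=NUM_BLOCKS)\n",
        "  # Uniform label smoothing (assumed range [0, 0.2]).\n",
        "  label_smoothing = rng.uniform(0.0, 0.2)\n",
        "  return {'l2': l2, 'label_smoothing': label_smoothing}\n",
        "\n",
        "rng = np.random.default_rng(seed=0)\n",
        "configs = [sample_hyperparameters(rng) for _ in range(NUM_TRIALS)]\n",
        "print(configs[0])\n",
        "```\n",
        "\n",
        "Each configuration would then be passed to one training run; the resulting checkpoints form the model pool from which the ensemble is selected."
      ]
    },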
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Pii1xJlIEWSp"
      },
      "outputs": [],
      "source": [
        "# The model architecture we want to form the ensemble over\n",
        "# here, we use the original ResNet-20 model by He et al. 2015.\n",
        "model = ub.models.wide_resnet(\n",
        "    input_shape=ds_info.features['image'].shape,\n",
        "    depth=22,\n",
        "    width_multiplier=1,\n",
        "    num_classes=num_classes,\n",
        "    l2=0.,\n",
        "    version=1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "executionInfo": {
          "elapsed": 56,
          "status": "ok",
          "timestamp": 1626166428808,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": -120
        },
        "id": "GhRSwqKyBr4S",
        "outputId": "c2d8626f-bc95-4515-84bc-7cf9a95ab97a"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Model pool size: 100\n"
          ]
        }
      ],
      "source": [
        "# Load checkpoints:\n",
        "# These are 100 checkpoints and loading will take a few minutes.\n",
        "ensemble_filenames = parse_checkpoint_dir(CHECKPOINT_DIR)\n",
        "model_pool_size = len(ensemble_filenames)\n",
        "checkpoint = tf.train.Checkpoint(model=model)\n",
        "print('Model pool size: {}'.format(model_pool_size))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nYUhBpODKd2v"
      },
      "source": [
        "## Step 2: Construction of the hyperparameter ensemble on the validation set\n",
        "\n",
        "First we compute the logits of all models in our model pool on the validation set."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "executionInfo": {
          "elapsed": 488509,
          "status": "ok",
          "timestamp": 1626167063340,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": -120
        },
        "id": "y4zJ1ZzFKOQB",
        "outputId": "5abcbde1-5a7b-4032-917c-2f54b7b99cc6"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "1.0% completion for prediction on validation set: model 1/100.\n",
            "11.0% completion for prediction on validation set: model 11/100.\n",
            "21.0% completion for prediction on validation set: model 21/100.\n",
            "31.0% completion for prediction on validation set: model 31/100.\n",
            "41.0% completion for prediction on validation set: model 41/100.\n",
            "51.0% completion for prediction on validation set: model 51/100.\n",
            "61.0% completion for prediction on validation set: model 61/100.\n",
            "71.0% completion for prediction on validation set: model 71/100.\n",
            "81.0% completion for prediction on validation set: model 81/100.\n",
            "91.0% completion for prediction on validation set: model 91/100.\n",
            "100.0% completion for prediction on validation set: model 100/100.\n"
          ]
        }
      ],
      "source": [
        "# Compute the logits on the validation set.\n",
        "val_logits, val_labels = [], []\n",
        "for m, ensemble_filename in enumerate(ensemble_filenames):\n",
        "  # Enforce memory clean-up.\n",
        "  tf.keras.backend.clear_session()\n",
        "  checkpoint.restore(ensemble_filename)\n",
        "  val_iterator = iter(val_dataset)\n",
        "  val_logits_m = []\n",
        "  for _ in range(steps_per_val_eval):\n",
        "    inputs = next(val_iterator)\n",
        "    features = inputs['features']\n",
        "    labels = inputs['labels']\n",
        "    val_logits_m.append(model(features, training=False))\n",
        "    if m == 0:\n",
        "      val_labels.append(labels)\n",
        "\n",
        "  val_logits.append(tf.concat(val_logits_m, axis=0))\n",
        "  if m == 0:\n",
        "    val_labels = tf.concat(val_labels, axis=0)\n",
        "\n",
        "  if m % 10 == 0 or m == model_pool_size - 1:\n",
        "    percent = (m + 1.) / model_pool_size\n",
        "    message = ('{:.1%} completion for prediction on validation set: '\n",
        "                'model {:d}/{:d}.'.format(percent, m + 1, model_pool_size))\n",
        "    print(message)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Mk6fdgeosFYG"
      },
      "source": [
        "Now we are ready to construct the ensemble.\n",
        "* In the first step, we take the best model (on the validation set) -\u003e `model_1`.\n",
        "* In the second step, we fix `model_1` and try all models in our model pool and construct the ensemble `[model_1, model_2]`. We select the model `model_2` that leads to the highest performance gain.\n",
        "* In the third step, we fix `model_1`, `model_2` and choose `model_3` to construct an ensemble `[model_1, model_2, model_3]` that leads to the highest performance gain over step 2.\n",
        "* ... and so on, until the desired ensemble size is reached or no performance gain could be achieved anymore."
      ]
    },
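    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The greedy steps listed above can be sketched on synthetic data with plain NumPy. This toy version is an illustration under stated assumptions: the random model pool, the function names, and the `max_steps` safeguard are hypothetical and are not part of the `greedy_selection` helper used in this notebook.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def ensemble_nll(labels, member_probs):\n",
        "  # NLL of the averaged probabilities; member_probs: [members, examples, classes].\n",
        "  probs = np.mean(member_probs, axis=0)\n",
        "  return -np.mean(np.log(probs[np.arange(len(labels)), labels]))\n",
        "\n",
        "def greedy_select(pool_probs, labels, max_size, max_steps=20):\n",
        "  # Greedy forward selection with replacement, minimizing the validation NLL.\n",
        "  # max_steps is a toy safeguard that bounds the number of selection rounds.\n",
        "  selected, best = [], np.inf\n",
        "  while len(set(selected)) < max_size and len(selected) < max_steps:\n",
        "    candidate = None\n",
        "    for model_id in range(len(pool_probs)):\n",
        "      nll = ensemble_nll(labels, pool_probs[selected + [model_id]])\n",
        "      if nll < best:\n",
        "        best, candidate = nll, model_id\n",
        "    if candidate is None:  # No candidate improves the objective.\n",
        "      break\n",
        "    selected.append(candidate)\n",
        "  return selected, best\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "toy_labels = rng.integers(0, 3, size=200)\n",
        "toy_logits = rng.normal(size=(6, 200, 3))  # 6 models, 200 examples, 3 classes.\n",
        "toy_logits[:, np.arange(200), toy_labels] += 1.0  # Make models weakly informative.\n",
        "toy_probs = np.exp(toy_logits) / np.exp(toy_logits).sum(axis=-1, keepdims=True)\n",
        "selected, nll = greedy_select(toy_probs, toy_labels, max_size=3)\n",
        "print(selected, nll)\n",
        "```\n",
        "\n",
        "Because the objective must strictly improve in every round, the first selected member is always the best single model in the pool, and the returned NLL can never be worse than that model's NLL."
      ]
    },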
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "executionInfo": {
          "elapsed": 1677,
          "status": "ok",
          "timestamp": 1626167306386,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": -120
        },
        "id": "8fU0FT_QLLnu",
        "outputId": "732dfdec-6a26-4c86-8327-9d52ce6c90f4"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Members selected by greedy procedure: model ids = [38, 0, 81, 27] (with 4 unique member(s)).\n"
          ]
        }
      ],
      "source": [
        "# Ensemble construction by greedy member selection on the validation set.\n",
        "selected_members, val_acc, val_nll = greedy_selection(val_logits, val_labels,\n",
        "                                                        ENSEMBLE_SIZE,\n",
        "                                                        objective='nll')\n",
        "unique_selected_members = list(set(selected_members))\n",
        "message = ('Members selected by greedy procedure: model ids = {} (with {} '\n",
        "            'unique member(s)).').format(\n",
        "                selected_members, len(unique_selected_members))\n",
        "print(message)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lhUtHjMituAn"
      },
      "source": [
        "# Evaluation on the test set\n",
        "\n",
        "Let's see how the **hyper-deep ensemble** performs on the test set."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "executionInfo": {
          "elapsed": 25716,
          "status": "ok",
          "timestamp": 1626167386820,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": -120
        },
        "id": "0Nk6-zOdLuH7",
        "outputId": "e3e81d71-4f14-4dc6-b360-5e3bd9681517"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Completed computation of member logits on the test set.\n"
          ]
        }
      ],
      "source": [
        "# Evaluate the following metrics on the test set.\n",
        "metrics = {\n",
        "    'ensemble/negative_log_likelihood': tf.keras.metrics.Mean(),\n",
        "    'ensemble/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),\n",
        "}\n",
        "metrics_single = {\n",
        "    'single/negative_log_likelihood': tf.keras.metrics.SparseCategoricalCrossentropy(),\n",
        "    'single/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),\n",
        "}\n",
        "\n",
        "\n",
        "# Compute logits for each ensemble member on the test set.\n",
        "logits_test = []\n",
        "for m, member_id in enumerate(unique_selected_members):\n",
        "  ensemble_filename = ensemble_filenames[member_id]\n",
        "  checkpoint.restore(ensemble_filename)\n",
        "  logits = []\n",
        "  test_iterator = iter(test_dataset)\n",
        "  for _ in range(steps_per_eval):\n",
        "    features = next(test_iterator)['features']\n",
        "    logits.append(model(features, training=False))\n",
        "  logits_test.append(tf.concat(logits, axis=0))\n",
        "logits_test = tf.convert_to_tensor(logits_test)\n",
        "print('Completed computation of member logits on the test set.')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "executionInfo": {
          "elapsed": 769,
          "status": "ok",
          "timestamp": 1626167387703,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": -120
        },
        "id": "JyU2A-uCu-Us",
        "outputId": "5d34b1b9-a62c-47e8-c143-36974441930a"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "0.6% completion final test prediction\n",
            "16.7% completion final test prediction\n",
            "32.7% completion final test prediction\n",
            "48.7% completion final test prediction\n",
            "64.7% completion final test prediction\n",
            "80.8% completion final test prediction\n",
            "96.8% completion final test prediction\n",
            "100.0% completion final test prediction\n"
          ]
        }
      ],
      "source": [
        "# Compute test metrics.\n",
        "test_iterator = iter(test_dataset)\n",
        "for step in range(steps_per_eval):\n",
        "  labels = next(test_iterator)['labels']\n",
        "  logits = logits_test[:, (step*BATCH_SIZE):((step+1)*BATCH_SIZE)]\n",
        "  labels = tf.cast(labels, tf.int32)\n",
        "  negative_log_likelihood = _ensemble_cross_entropy(labels, logits)\n",
        "  # Per member output probabilities.\n",
        "  per_probs = tf.nn.softmax(logits)\n",
        "  # Ensemble output probabilites.\n",
        "  probs = tf.reduce_mean(per_probs, axis=0)\n",
        "  metrics['ensemble/negative_log_likelihood'].update_state(\n",
        "      negative_log_likelihood)\n",
        "  metrics['ensemble/accuracy'].update_state(labels, probs)\n",
        "\n",
        "  # For comparison compute performance of the best single model,\n",
        "  # this is by definition the first model that was selected by the greedy \n",
        "  # selection method.\n",
        "  logits_single = logits_test[0, (step*BATCH_SIZE):((step+1)*BATCH_SIZE)]\n",
        "  probs_single = tf.nn.softmax(logits_single)\n",
        "  metrics_single['single/negative_log_likelihood'].update_state(labels, logits_single)\n",
        "  metrics_single['single/accuracy'].update_state(labels, probs_single)\n",
        "\n",
        "  percent = (step + 1) / steps_per_eval\n",
        "  if step % 25 == 0 or step == steps_per_eval - 1:\n",
        "    message = ('{:.1%} completion final test prediction'.format(percent))\n",
        "    print(message)\n",
        "\n",
        "ensemble_results = {name: metric.result() for name, metric in metrics.items()}\n",
        "single_results = {name: metric.result() for name, metric in metrics_single.items()}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "w3wkHMyfuyFO"
      },
      "source": [
        "## Here is the final ensemble performance\n",
        "\n",
        "We gained almost 2 percentage points in terms of accuracy over the best single model!"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "executionInfo": {
          "elapsed": 55,
          "status": "ok",
          "timestamp": 1626167392466,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": -120
        },
        "id": "3k3XGtpebXv3",
        "outputId": "87016d0a-6c34-4ac1-843a-ed0519fd1b63"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Ensemble performance:\n",
            "   ensemble/negative_log_likelihood: 0.19807305932044983\n",
            "   ensemble/accuracy: 0.9358974099159241\n",
            "\n",
            "For comparison:\n",
            "   single/negative_log_likelihood: 1.0325815677642822\n",
            "   single/accuracy: 0.9189703464508057\n"
          ]
        }
      ],
      "source": [
        "print('Ensemble performance:')\n",
        "for m, val in ensemble_results.items():\n",
        "  print('   {}: {}'.format(m, val))\n",
        "\n",
        "print('\\nFor comparison:')\n",
        "for m, val in single_results.items():\n",
        "  print('   {}: {}'.format(m, val))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IJiXCSmpy1W4"
      },
      "source": [
        "## Hyper-deep ensembles as a strong baseline\n",
        "\n",
        "We have seen that **hyper-deep ensembles** can lead to significant performance gains and can be easily implemented in your existing machine learning pipeline. Moreover, we hope that other researchers can benefit from this by using **hyper-deep ensembles** as a competitive, yet simple-to-implement, baseline. Even though **hyper-deep ensembles** might be more expensive than single model methods, it can show how much can be gained by introducing more diversity in the predictions.\n",
        "\n",
        "\n",
        "## Hyper-deep ensembles can make your ML pipeline more robust\n",
        "\n",
        "**Don't throw away your precious models!**\n",
        "\n",
        "In many settings where we use a standard (single model) deep neural network, we usually start with a hyperparameter search. Typically, we select the model with the best hyperparameters and throw away all the others. Here, we show that you can get a much more performant system by combining multiple models from the hyperparameter search. \n",
        "\n",
        "**What's the additional cost?**\n",
        "\n",
        "In most cases you already get a significant performance boost if you combine 4 models. The main additional cost (provided you have already done the hyperparameter search) is that your model is now 4x larger (more memory) and 4x times slower to perform the predictions (if not parallelized). Often the performance boost justifies this increased cost. If you can't afford the additional cost, check out **hyper-batch ensembles**. This is an efficient version that amortizes hyper-deep ensembles **within a single model** (see our [paper](https://arxiv.org/abs/2006.13570)).\n",
        "\n",
        "## Pointers to additional resources\n",
        "\n",
        "* The full code for the extended **hyper-deep ensembles** pipeline and the code for the experiments in our paper can be found in the [Uncertainty Baselines](https://github.com/google/uncertainty-baselines/blob/main/baselines/cifar/hyperdeepensemble.py) repository.\n",
        "* Our efficient version **hyper-batch ensembles** that amortize hyper-deep ensembles within a single model is implemented as a  keras layer and can be found in [Edward2](https://github.com/google/edward2 ).\n",
        "\n",
        "## For questions reach out to\n",
        "Florian Wenzel ([florianwenzel@google.com](mailto:florianwenzel@google.com)) \\\n",
        "Rodolphe Jenatton ([rjenatton@google.com](mailto:rjenatton@google.com))\n",
        "\n",
        "\n",
        "### Reference\n",
        "\n",
        "If you use parts of this pipeline we would be happy if you would cite our paper.\n",
        "\n",
        "\u003e Florian Wenzel, Jasper Snoek, Dustin Tran and Rodolphe Jenatton (2020).\n",
        "\u003e [Hyperparameter Ensembles for Robustness and Uncertainty Quantification](https://arxiv.org/abs/2006.13570).\n",
        "\u003e In _Neural Information Processing Systems_.\n",
        "\n",
        "```none\n",
        "@inproceedings{wenzel2020good,\n",
        "  author = {Florian Wenzel and Jasper Snoek and Dustin Tran and Rodolphe Jenatton},\n",
        "  title = {Hyperparameter Ensembles for Robustness and Uncertainty Quantification},\n",
        "  booktitle = {Neural Information Processing Systems)},\n",
        "  year = {2020},\n",
        "}\n",
        "```\n",
        "\n",
        "\n"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//learning/deepmind/dm_python:dm_notebook3",
        "kind": "private"
      },
      "name": "Hyperparemeter Ensembles.ipynb",
      "provenance": [
        {
          "file_id": "11WX0nHlYG8owHzEG3eiYYRqk1z4qcvl0",
          "timestamp": 1617806653259
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
